import time
import unittest
from unittest.mock import patch

from sqlglot import Parser, exp, parse, parse_one
from sqlglot.errors import ErrorLevel, ParseError
from sqlglot.parser import logger as parser_logger
from tests.helpers import assert_logger_contains


class TestParser(unittest.TestCase):
    def test_parse_empty(self):
        with self.assertRaises(ParseError):
            parse_one("")

    def test_parse_into(self):
        self.assertIsInstance(parse_one("(1)", into=exp.Tuple), exp.Tuple)
        self.assertIsInstance(parse_one("(1,)", into=exp.Tuple), exp.Tuple)
        self.assertIsInstance(parse_one("(x=1)", into=exp.Tuple), exp.Tuple)

        self.assertIsInstance(parse_one("select * from t", into=exp.Select), exp.Select)
        self.assertIsInstance(parse_one("select * from t limit 5", into=exp.Select), exp.Select)
        self.assertIsInstance(parse_one("left join foo", into=exp.Join), exp.Join)
        self.assertIsInstance(parse_one("int", into=exp.DataType), exp.DataType)
        self.assertIsInstance(parse_one("array<int>", into=exp.DataType), exp.DataType)
        self.assertIsInstance(parse_one("foo", into=exp.Table), exp.Table)
        self.assertIsInstance(
            parse_one(
                "WHEN MATCHED THEN UPDATE SET target.salary = COALESCE(source.salary, target.salary)",
                into=exp.Whens,
            ),
            exp.Whens,
        )

        with self.assertRaises(ParseError) as ctx:
            parse_one("SELECT * FROM tbl", into=exp.Table)

        self.assertEqual(
            str(ctx.exception),
            "Failed to parse 'SELECT * FROM tbl' into <class 'sqlglot.expressions.Table'>",
        )

        self.assertIsInstance(parse_one("foo INT NOT NULL", into=exp.ColumnDef), exp.ColumnDef)

    def test_parse_into_error(self):
        expected_message = "Failed to parse 'SELECT 1;' into [<class 'sqlglot.expressions.From'>]"
        expected_errors = [
            {
                "description": "Invalid expression / Unexpected token",
                "line": 1,
                "col": 6,
                "start_context": "",
                "highlight": "SELECT",
                "end_context": " 1;",
                "into_expression": exp.From,
            }
        ]
        with self.assertRaises(ParseError) as ctx:
            parse_one("SELECT 1;", read="sqlite", into=[exp.From])

        self.assertEqual(str(ctx.exception), expected_message)
        self.assertEqual(ctx.exception.errors, expected_errors)

    def test_parse_into_errors(self):
        expected_message = "Failed to parse 'SELECT 1;' into [<class 'sqlglot.expressions.From'>, <class 'sqlglot.expressions.Join'>]"
        expected_errors = [
            {
                "description": "Invalid expression / Unexpected token",
                "line": 1,
                "col": 6,
                "start_context": "",
                "highlight": "SELECT",
                "end_context": " 1;",
                "into_expression": exp.From,
            },
            {
                "description": "Invalid expression / Unexpected token",
                "line": 1,
                "col": 6,
                "start_context": "",
                "highlight": "SELECT",
                "end_context": " 1;",
                "into_expression": exp.Join,
            },
        ]
        with self.assertRaises(ParseError) as ctx:
            parse_one("SELECT 1;", "sqlite", into=[exp.From, exp.Join])

        self.assertEqual(str(ctx.exception), expected_message)
        self.assertEqual(ctx.exception.errors, expected_errors)

    def test_column(self):
        columns = parse_one("select a, ARRAY[1] b, case when 1 then 1 end").find_all(exp.Column)
        assert len(list(columns)) == 1

        self.assertIsNotNone(parse_one("date").find(exp.Column))

    def test_tuple(self):
        parse_one("(a,)").assert_is(exp.Tuple)

    def test_structs(self):
        cast = parse_one("cast(x as struct<int>)")
        self.assertIsInstance(cast.to.expressions[0], exp.DataType)
        self.assertEqual(cast.sql(), "CAST(x AS STRUCT<INT>)")

        cast = parse_one("cast(x as struct<varchar(10)>)")
        self.assertIsInstance(cast.to.expressions[0], exp.DataType)
        self.assertEqual(cast.sql(), "CAST(x AS STRUCT<VARCHAR(10)>)")

    def test_float(self):
        self.assertEqual(parse_one(".2"), parse_one("0.2"))

    def test_unnest(self):
        unnest_sql = "UNNEST(foo)"
        expr = parse_one(unnest_sql)
        self.assertIsInstance(expr, exp.Unnest)
        self.assertIsInstance(expr.expressions, list)
        self.assertEqual(expr.sql(), unnest_sql)

    def test_unnest_projection(self):
        expr = parse_one("SELECT foo IN UNNEST(bla) AS bar")
        self.assertIsInstance(expr.selects[0], exp.Alias)
        self.assertEqual(expr.selects[0].output_name, "bar")
        self.assertIsNotNone(parse_one("select unnest(x)").find(exp.Unnest))

    def test_unary_plus(self):
        self.assertEqual(parse_one("+15"), exp.Literal.number(15))

    def test_table(self):
        tables = [t.sql() for t in parse_one("select * from a, b.c, .d").find_all(exp.Table)]
        self.assertEqual(set(tables), {"a", "b.c", "d"})

    def test_union(self):
        self.assertIsInstance(parse_one("SELECT * FROM (SELECT 1) UNION SELECT 2"), exp.Union)
        self.assertIsInstance(
            parse_one("SELECT x FROM y HAVING x > (SELECT 1) UNION SELECT 2"), exp.Union
        )

        # Check that modifiers are attached to the topmost union node and not the rightmost query
        single_union = "SELECT x FROM t1 UNION ALL SELECT x FROM t2 LIMIT 1"
        expr = parse_one(single_union)
        limit = expr.assert_is(exp.Union).args.get("limit")
        self.assertIsInstance(limit, exp.Limit)
        self.assertEqual(expr.sql(), single_union)

        two_unions = (
            "SELECT x FROM t1 UNION ALL SELECT x FROM t2 UNION ALL SELECT x FROM t3 LIMIT 1"
        )
        expr = parse_one(two_unions)
        limit = expr.assert_is(exp.Union).args.get("limit")
        self.assertIsInstance(limit, exp.Limit)
        self.assertEqual(expr.sql(), two_unions)

        expr = parse_one(single_union, read="clickhouse")
        self.assertIsNone(expr.args.get("limit"))
        self.assertEqual(expr.sql(dialect="clickhouse"), single_union)

    def test_select(self):
        self.assertIsNotNone(parse_one("select 1 natural"))
        self.assertIsNotNone(parse_one("select * from (select 1) x order by x.y").args["order"])
        self.assertIsNotNone(
            parse_one("select * from x where a = (select 1) order by x.y").args["order"]
        )
        self.assertEqual(len(parse_one("select * from (select 1) x cross join y").args["joins"]), 1)
        self.assertEqual(
            parse_one("""SELECT * FROM x CROSS JOIN y, z LATERAL VIEW EXPLODE(y)""").sql(),
            """SELECT * FROM x CROSS JOIN y, z LATERAL VIEW EXPLODE(y)""",
        )
        self.assertIsNone(
            parse_one("create table a as (select b from c) index").find(exp.TableAlias)
        )

    def test_command(self):
        with self.assertLogs(parser_logger) as cm:
            expressions = parse("SET x = 1; ADD JAR s3://a; SELECT 1", read="hive")
            self.assertEqual(len(expressions), 3)
            self.assertEqual(expressions[0].sql(), "SET x = 1")
            self.assertEqual(expressions[1].sql(), "ADD JAR s3://a")
            self.assertEqual(expressions[2].sql(), "SELECT 1")

        assert "'ADD JAR s3://a'" in cm.output[0]

    def test_lambda_struct(self):
        expression = parse_one("FILTER(a.b, x -> x.id = id)")
        lambda_expr = expression.expression

        self.assertIsInstance(lambda_expr.this.this, exp.Dot)
        self.assertEqual(lambda_expr.sql(), "x -> x.id = id")

        self.assertIsNone(parse_one("FILTER([], x -> x)").find(exp.Column))

    def test_transactions(self):
        expression = parse_one("BEGIN TRANSACTION")
        self.assertIsNone(expression.this)
        self.assertEqual(expression.args["modes"], [])
        self.assertEqual(expression.sql(), "BEGIN")

        expression = parse_one("START TRANSACTION", read="mysql")
        self.assertIsNone(expression.this)
        self.assertEqual(expression.args["modes"], [])
        self.assertEqual(expression.sql(), "BEGIN")

        expression = parse_one("BEGIN DEFERRED TRANSACTION")
        self.assertEqual(expression.this, "DEFERRED")
        self.assertEqual(expression.args["modes"], [])
        self.assertEqual(expression.sql(), "BEGIN")

        expression = parse_one(
            "START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE", read="presto"
        )
        self.assertIsNone(expression.this)
        self.assertEqual(expression.args["modes"][0], "READ WRITE")
        self.assertEqual(expression.args["modes"][1], "ISOLATION LEVEL SERIALIZABLE")
        self.assertEqual(expression.sql(), "BEGIN READ WRITE, ISOLATION LEVEL SERIALIZABLE")

        expression = parse_one("BEGIN", read="bigquery")
        self.assertNotIsInstance(expression, exp.Transaction)
        self.assertIsNone(expression.expression)
        self.assertEqual(expression.sql(), "BEGIN")

    def test_identify(self):
        expression = parse_one(
            """
            SELECT a, "b", c AS c, d AS "D", e AS "y|z'"
            FROM y."z"
        """
        )

        assert expression.expressions[0].name == "a"
        assert expression.expressions[1].name == "b"
        assert expression.expressions[2].alias == "c"
        assert expression.expressions[3].alias == "D"
        assert expression.expressions[4].alias == "y|z'"
        table = expression.args["from_"].this
        assert table.name == "z"
        assert table.args["db"].name == "y"

    def test_multi(self):
        expressions = parse(
            """
            SELECT * FROM a; SELECT * FROM b;
        """
        )

        assert len(expressions) == 2
        assert expressions[0].args["from_"].name == "a"
        assert expressions[1].args["from_"].name == "b"

        expressions = parse("SELECT 1; ; SELECT 2")

        assert len(expressions) == 3
        assert expressions[1] is None

    def test_expression(self):
        ignore = Parser(error_level=ErrorLevel.IGNORE)
        self.assertIsInstance(ignore.expression(exp.Hint, expressions=[]), exp.Hint)
        self.assertIsInstance(ignore.expression(exp.Hint, y=""), exp.Hint)
        self.assertIsInstance(ignore.expression(exp.Hint), exp.Hint)

        default = Parser(error_level=ErrorLevel.RAISE)
        with self.assertRaises(TypeError):
            default.expression(exp.Hint, y="")
        self.assertIsInstance(default.expression(exp.Hint, expressions=[]), exp.Hint)
        default.expression(exp.Hint)
        self.assertEqual(len(default.errors), 2)

        warn = Parser(error_level=ErrorLevel.WARN)
        warn.expression(exp.Hint)
        self.assertEqual(len(warn.errors), 1)

    def test_parse_errors(self):
        with self.assertRaises(ParseError):
            parse_one("IF(a > 0, a, b, c)")

        with self.assertRaises(ParseError):
            parse_one("IF(a > 0)")

        with self.assertRaises(ParseError):
            parse_one("SELECT CASE FROM x")

        with self.assertRaises(ParseError):
            parse_one("WITH cte AS (SELECT * FROM x)")

        with self.assertRaises(ParseError):
            parse_one("SELECT foo( FROM bar")

        self.assertEqual(
            parse_one(
                "CREATE TABLE t (i UInt8) ENGINE = AggregatingMergeTree() ORDER BY tuple()",
                read="clickhouse",
                error_level=ErrorLevel.RAISE,
            ).sql(dialect="clickhouse"),
            "CREATE TABLE t (i UInt8) ENGINE=AggregatingMergeTree() ORDER BY tuple()",
        )

        with self.assertRaises(ParseError):
            parse_one("SELECT A[:")

        self.assertEqual(parse_one("as as", error_level=ErrorLevel.IGNORE).sql(), "AS as")

    def test_space(self):
        self.assertEqual(
            parse_one("SELECT ROW() OVER(PARTITION  BY x) FROM x GROUP  BY y").sql(),
            "SELECT ROW() OVER (PARTITION BY x) FROM x GROUP BY y",
        )

        self.assertEqual(
            parse_one(
                """SELECT   * FROM x GROUP
                BY y"""
            ).sql(),
            "SELECT * FROM x GROUP BY y",
        )

    def test_missing_by(self):
        with self.assertRaises(ParseError):
            parse_one("SELECT FROM x ORDER BY")

    def test_parameter(self):
        self.assertEqual(parse_one("SELECT @x, @@x, @1").sql(), "SELECT @x, @@x, @1")

    def test_var(self):
        self.assertIsInstance(parse_one("INTERVAL '1' DAY").args["unit"], exp.Var)
        self.assertEqual(parse_one("SELECT @JOIN, @'foo'").sql(), "SELECT @JOIN, @'foo'")

    def test_comments_select(self):
        expression = parse_one(
            """
            --comment1.1
            --comment1.2
            SELECT /*comment1.3*/
                a, --comment2
                b as B, --comment3:testing
                "test--annotation",
                c, --comment4 --foo
                e, --
                f -- space
            FROM foo
            """
        )

        self.assertEqual(expression.comments, ["comment1.1", "comment1.2", "comment1.3"])
        self.assertEqual(expression.expressions[0].comments, ["comment2"])
        self.assertEqual(expression.expressions[1].comments, ["comment3:testing"])
        self.assertEqual(expression.expressions[2].comments, None)
        self.assertEqual(expression.expressions[3].comments, ["comment4 --foo"])
        self.assertEqual(expression.expressions[4].comments, [""])
        self.assertEqual(expression.expressions[5].comments, [" space"])

        expression = parse_one(
            """
            SELECT a.column_name --# Comment 1
                   ,b.column_name2, --# Comment 2
                   b.column_name3 AS NAME3 --# Comment 3
            FROM table_name a
            JOIN table_name2 b ON a.column_name = b.column_name
            """
        )

        self.assertEqual(expression.expressions[0].comments, ["# Comment 1"])
        self.assertEqual(expression.expressions[1].comments, ["# Comment 2"])
        self.assertEqual(expression.expressions[2].comments, ["# Comment 3"])

    def test_comments_select_cte(self):
        expression = parse_one(
            """
            /*comment1.1*/
            /*comment1.2*/
            WITH a AS (SELECT 1)
            SELECT /*comment2*/
                a.*
            FROM /*comment3*/
                a
            """
        )

        self.assertEqual(expression.comments, ["comment2"])
        self.assertEqual(expression.args.get("from_").comments, ["comment3"])
        self.assertEqual(expression.args.get("with_").comments, ["comment1.1", "comment1.2"])

    def test_comments_insert(self):
        expression = parse_one(
            """
            --comment1.1
            --comment1.2
            INSERT INTO /*comment1.3*/
                x       /*comment2*/
            VALUES      /*comment3*/
                (1, 'a', 2.0)
        """
        )

        self.assertEqual(expression.comments, ["comment1.1", "comment1.2", "comment1.3"])
        self.assertEqual(expression.this.comments, ["comment2"])

    def test_comments_insert_cte(self):
        expression = parse_one(
            """
            /*comment1.1*/
            /*comment1.2*/
            WITH a AS (SELECT 1)
            INSERT INTO /*comment2*/
                b /*comment3*/
            SELECT * FROM a
        """
        )

        self.assertEqual(expression.comments, ["comment2"])
        self.assertEqual(expression.this.comments, ["comment3"])
        self.assertEqual(expression.args.get("with_").comments, ["comment1.1", "comment1.2"])

    def test_comments_update(self):
        expression = parse_one(
            """
            --comment1.1
            --comment1.2
            UPDATE  /*comment1.3*/
                tbl /*comment2*/
            SET     /*comment3*/
                x = 2
            WHERE /*comment4*/
                x <> 2
        """
        )

        self.assertEqual(expression.comments, ["comment1.1", "comment1.2", "comment1.3"])
        self.assertEqual(expression.this.comments, ["comment2"])
        self.assertEqual(expression.args.get("where").comments, ["comment4"])

    def test_comments_update_cte(self):
        expression = parse_one(
            """
            /*comment1.1*/
            /*comment1.2*/
            WITH a AS (SELECT * FROM b)
            UPDATE /*comment2*/
                a  /*comment3*/
            SET col = 1
        """
        )

        self.assertEqual(expression.comments, ["comment2"])
        self.assertEqual(expression.this.comments, ["comment3"])
        self.assertEqual(expression.args.get("with_").comments, ["comment1.1", "comment1.2"])

    def test_comments_delete(self):
        expression = parse_one(
            """
            --comment1.1
            --comment1.2
            DELETE /*comment1.3*/
            FROM   /*comment2*/
                x  /*comment3*/
            WHERE  /*comment4*/
                y > 1
        """
        )

        self.assertEqual(expression.comments, ["comment1.1", "comment1.2", "comment1.3"])
        self.assertEqual(expression.this.comments, ["comment3"])
        self.assertEqual(expression.args.get("where").comments, ["comment4"])

    def test_comments_delete_cte(self):
        expression = parse_one(
            """
            /*comment1.1*/
            /*comment1.2*/
            WITH a AS (SELECT * FROM b)
            --comment2
            DELETE FROM a /*comment3*/
        """
        )

        self.assertEqual(expression.comments, ["comment2"])
        self.assertEqual(expression.this.comments, ["comment3"])
        self.assertEqual(expression.args["with_"].comments, ["comment1.1", "comment1.2"])

    def test_type_literals(self):
        self.assertEqual(parse_one("int 1"), parse_one("CAST(1 AS INT)"))
        self.assertEqual(parse_one("int.5"), parse_one("CAST(0.5 AS INT)"))
        self.assertEqual(
            parse_one("TIMESTAMP '2022-01-01'").sql(), "CAST('2022-01-01' AS TIMESTAMP)"
        )
        self.assertEqual(
            parse_one("TIMESTAMP(1) '2022-01-01'").sql(), "CAST('2022-01-01' AS TIMESTAMP(1))"
        )
        self.assertEqual(
            parse_one("TIMESTAMP WITH TIME ZONE '2022-01-01'").sql(),
            "CAST('2022-01-01' AS TIMESTAMPTZ)",
        )
        self.assertEqual(
            parse_one("TIMESTAMP WITH LOCAL TIME ZONE '2022-01-01'").sql(),
            "CAST('2022-01-01' AS TIMESTAMPLTZ)",
        )
        self.assertEqual(
            parse_one("TIMESTAMP WITHOUT TIME ZONE '2022-01-01'").sql(),
            "CAST('2022-01-01' AS TIMESTAMP)",
        )
        self.assertEqual(
            parse_one("TIMESTAMP(1) WITH TIME ZONE '2022-01-01'").sql(),
            "CAST('2022-01-01' AS TIMESTAMPTZ(1))",
        )
        self.assertEqual(
            parse_one("TIMESTAMP(1) WITH LOCAL TIME ZONE '2022-01-01'").sql(),
            "CAST('2022-01-01' AS TIMESTAMPLTZ(1))",
        )
        self.assertEqual(
            parse_one("TIMESTAMP(1) WITHOUT TIME ZONE '2022-01-01'").sql(),
            "CAST('2022-01-01' AS TIMESTAMP(1))",
        )
        self.assertEqual(parse_one("TIMESTAMP(1) WITH TIME ZONE").sql(), "TIMESTAMPTZ(1)")
        self.assertEqual(parse_one("TIMESTAMP(1) WITH LOCAL TIME ZONE").sql(), "TIMESTAMPLTZ(1)")
        self.assertEqual(parse_one("TIMESTAMP(1) WITHOUT TIME ZONE").sql(), "TIMESTAMP(1)")
        self.assertEqual(parse_one("""JSON '{"x":"y"}'""").sql(), """PARSE_JSON('{"x":"y"}')""")
        self.assertIsInstance(parse_one("TIMESTAMP(1)"), exp.Func)
        self.assertIsInstance(parse_one("TIMESTAMP('2022-01-01')"), exp.Func)
        self.assertIsInstance(parse_one("TIMESTAMP()"), exp.Func)
        self.assertIsInstance(parse_one("map.x"), exp.Column)
        self.assertIsInstance(parse_one("CAST(x AS CHAR(5))").to.expressions[0], exp.DataTypeParam)
        self.assertEqual(parse_one("1::int64", dialect="bigquery"), parse_one("CAST(1 AS BIGINT)"))

    def test_set_expression(self):
        set_ = parse_one("SET")

        self.assertEqual(set_.sql(), "SET")
        self.assertIsInstance(set_, exp.Set)

        set_session = parse_one("SET SESSION x = 1")

        self.assertEqual(set_session.sql(), "SET SESSION x = 1")
        self.assertIsInstance(set_session, exp.Set)

        set_item = set_session.expressions[0]

        self.assertIsInstance(set_item, exp.SetItem)
        self.assertIsInstance(set_item.this, exp.EQ)
        self.assertIsInstance(set_item.this.this, exp.Column)
        self.assertIsInstance(set_item.this.expression, exp.Literal)

        self.assertEqual(set_item.args.get("kind"), "SESSION")

        set_to = parse_one("SET x TO 1")

        self.assertEqual(set_to.sql(), "SET x = 1")
        self.assertIsInstance(set_to, exp.Set)

        with self.assertLogs(parser_logger) as cm:
            set_as_command = parse_one("SET DEFAULT ROLE ALL TO USER")
            assert "'SET DEFAULT ROLE ALL TO USER'" in cm.output[0]

        self.assertEqual(set_as_command.sql(), "SET DEFAULT ROLE ALL TO USER")

        self.assertIsInstance(set_as_command, exp.Command)
        self.assertEqual(set_as_command.this, "SET")
        self.assertEqual(set_as_command.expression, " DEFAULT ROLE ALL TO USER")

    def test_pretty_config_override(self):
        self.assertEqual(parse_one("SELECT col FROM x").sql(), "SELECT col FROM x")
        with patch("sqlglot.pretty", True):
            self.assertEqual(parse_one("SELECT col FROM x").sql(), "SELECT\n  col\nFROM x")

        self.assertEqual(parse_one("SELECT col FROM x").sql(pretty=True), "SELECT\n  col\nFROM x")

    @patch("sqlglot.parser.logger")
    def test_comment_error_n(self, logger):
        parse_one(
            """SUM
(
-- test
)""",
            error_level=ErrorLevel.WARN,
        )

        assert_logger_contains(
            "Required keyword: 'this' missing for <class 'sqlglot.expressions.Sum'>. Line 4, Col: 1.",
            logger,
        )

    @patch("sqlglot.parser.logger")
    def test_comment_error_r(self, logger):
        parse_one(
            """SUM(-- test\r)""",
            error_level=ErrorLevel.WARN,
        )

        assert_logger_contains(
            "Required keyword: 'this' missing for <class 'sqlglot.expressions.Sum'>. Line 2, Col: 1.",
            logger,
        )

    @patch("sqlglot.parser.logger")
    def test_create_table_error(self, logger):
        parse_one(
            """CREATE TABLE SELECT""",
            error_level=ErrorLevel.WARN,
        )

        assert_logger_contains(
            "Expected table name",
            logger,
        )

    def test_pivot_columns(self):
        nothing_aliased = """
            SELECT * FROM (
                SELECT partname, price FROM part
            ) PIVOT (AVG(price) FOR partname IN ('prop', 'rudder'))
        """

        everything_aliased = """
            SELECT * FROM (
                SELECT partname, price FROM part
            ) PIVOT (AVG(price) AS avg_price FOR partname IN ('prop' AS prop1, 'rudder' AS rudder1))
        """

        only_pivot_columns_aliased = """
            SELECT * FROM (
                SELECT partname, price FROM part
            ) PIVOT (AVG(price) FOR partname IN ('prop' AS prop1, 'rudder' AS rudder1))
        """

        columns_partially_aliased = """
            SELECT * FROM (
                SELECT partname, price FROM part
            ) PIVOT (AVG(price) FOR partname IN ('prop' AS prop1, 'rudder'))
        """

        multiple_aggregates_aliased = """
            SELECT * FROM (
                SELECT partname, price, quality FROM part
            ) PIVOT (AVG(price) AS p, MAX(quality) AS q FOR partname IN ('prop' AS prop1, 'rudder'))
        """

        multiple_aggregates_not_aliased = """
            SELECT * FROM (
                SELECT partname, price, quality FROM part
            ) PIVOT (AVG(price), MAX(quality) FOR partname IN ('prop' AS prop1, 'rudder'))
        """

        multiple_aggregates_not_aliased_with_quoted_identifier_spark = """
            SELECT * FROM (
                SELECT partname, price, quality FROM part
            ) PIVOT (AVG(`PrIcE`), MAX(quality) FOR partname IN ('prop' AS prop1, 'rudder'))
        """

        multiple_aggregates_not_aliased_with_quoted_identifier_duckdb = """
            SELECT * FROM (
                SELECT partname, price, quality FROM part
            ) PIVOT (AVG("PrIcE"), MAX(quality) FOR partname IN ('prop' AS prop1, 'rudder'))
        """

        two_in_clauses_duckdb = """
            SELECT * FROM cities PIVOT (
                sum(population) AS total,
                count(population) AS count
                FOR
                    year IN (2000, 2010)
                    country IN ('NL', 'US')
            )
        """

        three_in_clauses_duckdb = """
            SELECT * FROM cities PIVOT (
                sum(population) AS total,
                count(population) AS count
                FOR
                    year IN (2000, 2010)
                    country IN ('NL', 'US')
                    name IN ('Amsterdam', 'Seattle')
            )
        """

        query_to_column_names = {
            nothing_aliased: {
                "bigquery": ["prop", "rudder"],
                "duckdb": ["prop", "rudder"],
                "redshift": ["prop", "rudder"],
                "snowflake": ['''"'prop'"''', '''"'rudder'"'''],
                "spark": ["prop", "rudder"],
            },
            everything_aliased: {
                "bigquery": ["avg_price_prop1", "avg_price_rudder1"],
                "duckdb": ["prop1_avg_price", "rudder1_avg_price"],
                "redshift": ["prop1_avg_price", "rudder1_avg_price"],
                "spark": ["prop1", "rudder1"],
            },
            only_pivot_columns_aliased: {
                "bigquery": ["prop1", "rudder1"],
                "duckdb": ["prop1", "rudder1"],
                "redshift": ["prop1", "rudder1"],
                "spark": ["prop1", "rudder1"],
            },
            columns_partially_aliased: {
                "bigquery": ["prop1", "rudder"],
                "duckdb": ["prop1", "rudder"],
                "redshift": ["prop1", "rudder"],
                "spark": ["prop1", "rudder"],
            },
            multiple_aggregates_aliased: {
                "bigquery": ["p_prop1", "q_prop1", "p_rudder", "q_rudder"],
                "duckdb": ["prop1_p", "prop1_q", "rudder_p", "rudder_q"],
                "spark": ["prop1_p", "prop1_q", "rudder_p", "rudder_q"],
            },
            multiple_aggregates_not_aliased: {
                "duckdb": [
                    '"prop1_avg(price)"',
                    '"prop1_max(quality)"',
                    '"rudder_avg(price)"',
                    '"rudder_max(quality)"',
                ],
                "spark": [
                    "`prop1_avg(price)`",
                    "`prop1_max(quality)`",
                    "`rudder_avg(price)`",
                    "`rudder_max(quality)`",
                ],
            },
            multiple_aggregates_not_aliased_with_quoted_identifier_spark: {
                "spark": [
                    "`prop1_avg(PrIcE)`",
                    "`prop1_max(quality)`",
                    "`rudder_avg(PrIcE)`",
                    "`rudder_max(quality)`",
                ],
            },
            multiple_aggregates_not_aliased_with_quoted_identifier_duckdb: {
                "duckdb": [
                    '"prop1_avg(PrIcE)"',
                    '"prop1_max(quality)"',
                    '"rudder_avg(PrIcE)"',
                    '"rudder_max(quality)"',
                ],
            },
            two_in_clauses_duckdb: {
                "duckdb": [
                    '"2000_NL_total"',
                    '"2000_NL_count"',
                    '"2000_US_total"',
                    '"2000_US_count"',
                    '"2010_NL_total"',
                    '"2010_NL_count"',
                    '"2010_US_total"',
                    '"2010_US_count"',
                ],
            },
            three_in_clauses_duckdb: {
                "duckdb": [
                    '"2000_NL_Amsterdam_total"',
                    '"2000_NL_Amsterdam_count"',
                    '"2000_NL_Seattle_total"',
                    '"2000_NL_Seattle_count"',
                    '"2000_US_Amsterdam_total"',
                    '"2000_US_Amsterdam_count"',
                    '"2000_US_Seattle_total"',
                    '"2000_US_Seattle_count"',
                    '"2010_NL_Amsterdam_total"',
                    '"2010_NL_Amsterdam_count"',
                    '"2010_NL_Seattle_total"',
                    '"2010_NL_Seattle_count"',
                    '"2010_US_Amsterdam_total"',
                    '"2010_US_Amsterdam_count"',
                    '"2010_US_Seattle_total"',
                    '"2010_US_Seattle_count"',
                ],
            },
        }

        for query, dialect_columns in query_to_column_names.items():
            for dialect, expected_columns in dialect_columns.items():
                with self.subTest(f"Testing query '{query}' for dialect {dialect}"):
                    expr = parse_one(query, read=dialect)
                    columns = expr.args["from_"].this.args["pivots"][0].args["columns"]
                    self.assertEqual(
                        expected_columns, [col.sql(dialect=dialect) for col in columns]
                    )

    def test_parse_nested(self):
        def warn_over_threshold(query: str, max_threshold: float = 0.2):
            now = time.time()
            ast = parse_one(query)
            end = time.time() - now

            self.assertIsNotNone(ast)
            if end >= max_threshold:
                parser_logger.warning(
                    f"Query {query[:100]}... surpassed the time threshold of {max_threshold} seconds"
                )

        warn_over_threshold("SELECT * FROM a " + ("LEFT JOIN b ON a.id = b.id " * 38))
        warn_over_threshold("SELECT * FROM a " + ("LEFT JOIN UNNEST(ARRAY[]) " * 15))
        warn_over_threshold("SELECT * FROM a " + ("OUTER APPLY (SELECT * FROM b) " * 30))
        warn_over_threshold("SELECT * FROM a " + ("NATURAL FULL OUTER JOIN x " * 30))

    def test_parse_properties(self):
        self.assertEqual(
            parse_one("create materialized table x").sql(), "CREATE MATERIALIZED TABLE x"
        )

    def test_parse_floats(self):
        self.assertTrue(parse_one("1. ").is_number)

    def test_parse_terse_coalesce(self):
        self.assertIsNotNone(parse_one("SELECT x ?? y FROM z").find(exp.Coalesce))
        self.assertEqual(
            parse_one("SELECT a, b ?? 'No Data' FROM z").sql(),
            "SELECT a, COALESCE(b, 'No Data') FROM z",
        )
        self.assertEqual(
            parse_one("SELECT a, b ?? c ?? 'No Data' FROM z").sql(),
            "SELECT a, COALESCE(COALESCE(b, c), 'No Data') FROM z",
        )

    def test_parse_intervals(self):
        ast = parse_one(
            "SELECT a FROM tbl WHERE a <= DATE '1998-12-01' - INTERVAL '71 days' GROUP BY b"
        )

        self.assertEqual(ast.find(exp.Interval).this.sql(), "'71'")
        self.assertEqual(ast.find(exp.Interval).unit.assert_is(exp.Var).sql(), "DAYS")

    def test_parse_concat_ws(self):
        """CONCAT_WS keeps its separator as the first entry of ``expressions``."""
        ast = parse_one("CONCAT_WS(' ', 'John', 'Doe')")

        self.assertEqual(ast.sql(), "CONCAT_WS(' ', 'John', 'Doe')")
        for position, expected_sql in enumerate(("' '", "'John'", "'Doe'")):
            self.assertEqual(ast.expressions[position].sql(), expected_sql)

        # An argument-less call must still parse when the error level is IGNORE.
        ignored = parse("CONCAT_WS()", error_level=ErrorLevel.IGNORE)
        self.assertEqual(ignored[0].sql(), "CONCAT_WS()")

    def test_parse_drop_schema(self):
        """A two-part DROP SCHEMA name maps to catalog/db, leaving the table part empty."""
        sql = "DROP SCHEMA catalog.schema"
        expected_ast = exp.Drop(
            this=exp.Table(
                this=None,
                db=exp.Identifier(this="schema", quoted=False),
                catalog=exp.Identifier(this="catalog", quoted=False),
            ),
            kind="SCHEMA",
        )

        for dialect in (None, "bigquery", "snowflake"):
            with self.subTest(dialect):
                parsed = parse_one(sql, dialect=dialect)
                self.assertEqual(parsed, expected_ast)
                self.assertEqual(parsed.sql(dialect=dialect), sql)

    def test_parse_create_schema(self):
        """A two-part CREATE SCHEMA name maps to catalog/db, leaving the table part empty."""
        sql = "CREATE SCHEMA catalog.schema"
        expected_ast = exp.Create(
            this=exp.Table(
                this=None,
                db=exp.Identifier(this="schema", quoted=False),
                catalog=exp.Identifier(this="catalog", quoted=False),
            ),
            kind="SCHEMA",
        )

        for dialect in (None, "bigquery", "snowflake"):
            with self.subTest(dialect):
                parsed = parse_one(sql, dialect=dialect)
                self.assertEqual(parsed, expected_ast)
                self.assertEqual(parsed.sql(dialect=dialect), sql)

    def test_values_as_identifier(self):
        """``values`` used as a column name round-trips unchanged in each listed dialect."""
        sql = "SELECT values FROM t WHERE values + 1 > x"
        dialects = ("bigquery", "clickhouse", "duckdb", "postgres", "redshift", "snowflake")

        for dialect in dialects:
            with self.subTest(dialect):
                regenerated = parse_one(sql, dialect=dialect).sql(dialect=dialect)
                self.assertEqual(regenerated, sql)

    def test_alter_set(self):
        """ALTER TABLE ... SET variants round-trip unchanged in Hive, Spark and Databricks."""
        statements = (
            "ALTER TABLE tbl SET TBLPROPERTIES ('x'='1', 'Z'='2')",
            "ALTER TABLE tbl SET SERDE 'test' WITH SERDEPROPERTIES ('k'='v', 'kay'='vee')",
            "ALTER TABLE tbl SET SERDEPROPERTIES ('k'='v', 'kay'='vee')",
            "ALTER TABLE tbl SET LOCATION 'new_location'",
            "ALTER TABLE tbl SET FILEFORMAT file_format",
            "ALTER TABLE tbl SET TAGS ('tag1' = 't1', 'tag2' = 't2')",
        )
        dialects = ("hive", "spark2", "spark", "databricks")

        for dialect in dialects:
            for sql in statements:
                with self.subTest(f"Testing query '{sql}' for dialect {dialect}"):
                    regenerated = parse_one(sql, dialect=dialect).sql(dialect=dialect)
                    self.assertEqual(regenerated, sql)

    def test_distinct_from(self):
        """OR binds looser than IS DISTINCT FROM, so the parsed root is an Or node."""
        root = parse_one("a IS DISTINCT FROM b OR c IS DISTINCT FROM d")
        self.assertIsInstance(root, exp.Or)

    def test_trailing_comments(self):
        """A comment after the final statement is kept and re-emitted as a block comment."""
        sql = """
        select * from x;
        -- my comment
            """
        regenerated = ";\n".join(statement.sql() for statement in parse(sql))
        self.assertEqual(regenerated, "SELECT * FROM x;\n/* my comment */")

    def test_parse_prop_eq(self):
        """``:=`` inside a function call's argument list parses as a PropertyEQ."""
        first_argument = parse_one("x(a := b and c)").expressions[0]
        self.assertIsInstance(first_argument, exp.PropertyEQ)

    def test_collate(self):
        """The COLLATE operand's node type depends on how the collation is spelled.

        Each case pairs a collation spelling with the expression type expected in the
        ``expression`` arg of the resulting ``exp.Collate`` node.
        """
        cases = (
            ('pg_catalog."default"', exp.Column),
            ('"en_DE"', exp.Identifier),
            ("LATIN1_GENERAL_BIN", exp.Var),
            ("'en'", exp.Literal),
        )

        for collation, expected_type in cases:
            # subTest reports every failing collation, and names it, instead of
            # aborting the whole test at the first mismatch.
            with self.subTest(collation=collation):
                collate_node = parse_one(
                    f"SELECT * FROM t WHERE foo LIKE '%bar%' COLLATE {collation}"
                ).find(exp.Collate)
                self.assertIsInstance(collate_node, exp.Collate)
                self.assertIsInstance(collate_node.expression, expected_type)

    def test_drop_column(self):
        """ALTER TABLE ... DROP COLUMN yields exactly one Table node and one Column node."""
        ast = parse_one("ALTER TABLE tbl DROP COLUMN col")
        for node_type in (exp.Table, exp.Column):
            self.assertEqual(len(list(ast.find_all(node_type))), 1)

    def test_udf_meta(self):
        """A ``/* sqlglot.anonymous */`` comment forces a known function to parse as Anonymous."""
        cases = (
            ("YEAR(a) /* sqlglot.anonymous */", exp.Anonymous),
            # The meta flag is matched case-sensitively.
            ("YEAR(a) /* sqlglot.anONymous */", exp.Year),
            # Incomplete or incorrect anonymous meta comments are not registered.
            ("YEAR(a) /* sqlglot.anon */", exp.Year),
        )
        for sql, expected_type in cases:
            self.assertIsInstance(parse_one(sql), expected_type)

    def test_token_position_meta(self):
        """Identifiers and leaf nodes carry line/col/start/end positions from their tokens."""
        ast = parse_one(
            "SELECT a, b FROM test_schema.test_table_a UNION ALL SELECT c, d FROM test_catalog.test_schema.test_table_b"
        )
        for identifier in ast.find_all(exp.Identifier):
            self.assertEqual(set(identifier.meta), {"line", "col", "start", "end"})

        # Tables referenced by the left (ast.this) and right (ast.expression) SELECTs.
        left_table = ast.this.args["from_"].this
        right_table = ast.expression.args["from_"].this
        identifier_positions = (
            (left_table.args["this"], {"line": 1, "col": 41, "start": 29, "end": 40}),
            (left_table.args["db"], {"line": 1, "col": 28, "start": 17, "end": 27}),
            (right_table.args["this"], {"line": 1, "col": 106, "start": 94, "end": 105}),
            (right_table.args["db"], {"line": 1, "col": 93, "start": 82, "end": 92}),
            (right_table.args["catalog"], {"line": 1, "col": 81, "start": 69, "end": 80}),
        )
        for node, expected_meta in identifier_positions:
            self.assertEqual(node.meta, expected_meta)

        # Non-identifier nodes located by type in small standalone queries.
        node_positions = (
            ("SELECT FOO()", exp.Anonymous, {"line": 1, "col": 10, "start": 7, "end": 9}),
            ("SELECT * FROM t", exp.Star, {"line": 1, "col": 8, "start": 7, "end": 7}),
            ("SELECT t.* FROM t", exp.Star, {"line": 1, "col": 10, "start": 9, "end": 9}),
            ("SELECT 1", exp.Literal, {"line": 1, "col": 8, "start": 7, "end": 7}),
        )
        for sql, node_type, expected_meta in node_positions:
            self.assertEqual(parse_one(sql).find(node_type).meta, expected_meta)

    def test_quoted_identifier_meta(self):
        """For quoted identifiers, ``sql[start : end + 1]`` reproduces the quotes as well."""
        sql = 'SELECT "a" FROM "test_schema"."test_table_a"'
        table = parse_one(sql).args["from_"].this

        for node, expected_text in (
            (table.args["db"], '"test_schema"'),
            (table.this, '"test_table_a"'),
        ):
            meta = node.meta
            self.assertEqual(sql[meta["start"] : meta["end"] + 1], expected_text)

    def test_qualified_function(self):
        """A dotted function call is built from Dot nodes and produces no Column nodes.

        The qualified name ``a.b.c.d.e.f.g.foo`` contains seven dots, and the test
        expects one Dot node per dot.
        """
        ast = parse_one("a.b.c.d.e.f.g.foo()")
        # unittest assertions instead of bare `assert`: they still run under
        # `python -O` and report the offending values when they fail.
        self.assertFalse(any(isinstance(node, exp.Column) for node in ast.walk()))
        self.assertEqual(len(list(ast.find_all(exp.Dot))), 7)

    def test_pivot_missing_agg_func(self):
        """A PIVOT with no aggregation function is rejected with a descriptive ParseError."""
        sql = "select * from tbl pivot(col1 for col2 in (val1, val1))"
        with self.assertRaises(ParseError) as raised:
            parse_one(sql)

        self.assertIn("Expecting an aggregation function in PIVOT", str(raised.exception))

    def test_multiple_query_modifiers(self):
        """A repeated WHERE clause raises by default; under IGNORE the last one is kept."""
        sql = "SELECT * FROM a WHERE b = 'true' AND c > 50 WHERE c = 'false'"

        with self.assertRaises(ParseError) as raised:
            parse_one(sql)
        self.assertIn("Found multiple 'WHERE' clauses. Line 1, Col: 49.", str(raised.exception))

        lenient = parse_one(sql, error_level=ErrorLevel.IGNORE)
        self.assertEqual(lenient.sql(), "SELECT * FROM a WHERE c = 'false'")

    def test_parse_into_grant_principal(self):
        """Principals with or without a ROLE/GROUP keyword parse into GrantPrincipal."""
        for sql in ("ROLE blah", "GROUP blah", "blah"):
            self.assertIsInstance(parse_one(sql, into=exp.GrantPrincipal), exp.GrantPrincipal)

        # A Databricks backtick-quoted name must also regenerate unchanged.
        principal = parse_one("ROLE `blah`", into=exp.GrantPrincipal, dialect="databricks")
        self.assertIsInstance(principal, exp.GrantPrincipal)
        self.assertEqual(principal.sql(dialect="databricks"), "ROLE `blah`")

    def test_parse_into_grant_privilege(self):
        """Single-word and multi-word privileges both parse into GrantPrivilege."""
        for sql in ("SELECT", "ALL PRIVILEGES"):
            self.assertIsInstance(parse_one(sql, into=exp.GrantPrivilege), exp.GrantPrivilege)
